{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.7.7 64-bit ('nupic.research': conda)",
   "metadata": {
    "interpreter": {
     "hash": "2e19d0973aec7b9e5fc894021fb72a8408234a9c26a5fd32051dcbfb7e6930f2"
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from os.path import expanduser\n",
    "sys.path.insert(0, expanduser(\"~/nta/nupic.research/projects/transformers\"))\n",
    "\n",
    "import torch\n",
    "from experiments import CONFIGS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pprint import pprint\n",
    "from hashlib import blake2b\n",
    "\n",
    "from transformers import (\n",
    "    HfArgumentParser,\n",
    "    AutoTokenizer,\n",
    ")\n",
    "from datasets import load_dataset, DatasetDict, concatenate_datasets, load_from_disk\n",
    "\n",
    "from experiments import CONFIGS\n",
    "from run_args import CustomTrainingArguments, DataTrainingArguments, ModelArguments\n",
    "\n",
    "from run_utils import preprocess_datasets_mlm, hash_dataset_folder_name\n"
   ]
  },
  {
   "source": [
    "## Load the default arguments for our datasets"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_100k = CONFIGS[\"bert_100k\"]\n",
    "exp_parser = HfArgumentParser(\n",
    "    (ModelArguments, DataTrainingArguments, CustomTrainingArguments)\n",
    ")\n",
    "model_args, data_args, training_args = exp_parser.parse_dict(bert_100k)"
   ]
  },
  {
   "source": [
    "Modify default arguments for new little dataset"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_args.dataset_name = (\"wikipedia_plus_bookcorpus\", )\n",
    "data_args.dataset_config_name = (None, )\n",
    "data_args.max_seq_length = 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "{'data_collator': 'DataCollatorForWholeWordMask',\n 'dataset_config_name': None,\n 'dataset_name': 'wikipedia_plus_bookcorpus',\n 'line_by_line': False,\n 'max_seq_length': 128,\n 'mlm_probability': 0.15,\n 'override_finetuning_results': False,\n 'overwrite_cache': False,\n 'pad_to_max_length': False,\n 'preprocessing_num_workers': None,\n 'reuse_tokenized_data': True,\n 'save_tokenized_data': True,\n 'task_name': None,\n 'task_names': [],\n 'tokenized_data_cache_dir': '/mnt/efs/results/preprocessed-datasets/text',\n 'train_file': None,\n 'validation_file': None,\n 'validation_split_percentage': 5}\n"
     ]
    }
   ],
   "source": [
    "pprint(data_args.__dict__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "source": [
    "## Load a fraction of Wikipedia and a fraction of Book Corpus\n",
    "These will be used to make a custom small dataset"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "source": [
    "Load 1% of Wikipedia for training and another 1% for validation"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "Reusing dataset wikipedia (/mnt/efs/results/cache/huggingface/datasets/wikipedia/20200501.en/1.0.0/4021357e28509391eab2f8300d9b689e7e8f3a877ebb3d354b01577d497ebc63)\n",
      "Reusing dataset wikipedia (/mnt/efs/results/cache/huggingface/datasets/wikipedia/20200501.en/1.0.0/4021357e28509391eab2f8300d9b689e7e8f3a877ebb3d354b01577d497ebc63)\n"
     ]
    }
   ],
   "source": [
    "cache_dir = \"/mnt/efs/results/cache/huggingface/datasets/\"\n",
    "wiki_dataset_train = load_dataset(\n",
    "    \"wikipedia\", \"20200501.en\",\n",
    "    cache_dir=cache_dir,\n",
    "    split=\"train[:1%]\"\n",
    ")\n",
    "wiki_dataset_val = load_dataset(\n",
    "    \"wikipedia\", \"20200501.en\",\n",
    "    cache_dir=cache_dir,\n",
    "    split=f\"train[1%:2%]\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "wiki_dataset_train.remove_columns_(\"title\")\n",
    "wiki_dataset_val.remove_columns_(\"title\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Load 8% of Book Corpus for training and another 8% for validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "Reusing dataset bookcorpus (/mnt/efs/results/cache/huggingface/datasets/bookcorpus/plain_text/1.0.0/af844be26c089fb64810e9f2cd841954fd8bd596d6ddd26326e4c70e2b8c96fc)\n",
      "Reusing dataset bookcorpus (/mnt/efs/results/cache/huggingface/datasets/bookcorpus/plain_text/1.0.0/af844be26c089fb64810e9f2cd841954fd8bd596d6ddd26326e4c70e2b8c96fc)\n"
     ]
    }
   ],
   "source": [
    "book_dataset_train = load_dataset(\"bookcorpus\", None, cache_dir=cache_dir, split=\"train[:8%]\")\n",
    "book_dataset_val = load_dataset(\"bookcorpus\", None, cache_dir=cache_dir, split=\"train[8:16%]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert wiki_dataset_train.features.type == \\\n",
    "    wiki_dataset_val.features.type == \\\n",
    "    book_dataset_train.features.type == \\\n",
    "    book_dataset_val.features.type\n"
   ]
  },
  {
   "source": [
    "Concatenate the datasets"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = DatasetDict()\n",
    "train_datasets = [wiki_dataset_train, book_dataset_train]\n",
    "validation_datasets = [wiki_dataset_val, book_dataset_val]\n",
    "\n",
    "datasets[\"train\"] = concatenate_datasets(train_datasets)\n",
    "datasets[\"validation\"] = concatenate_datasets(validation_datasets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets[\"train\"].info.description = \\\n",
    "\"\"\"\n",
    "A combination of Wikipedia plus Book Corpus. This uses the range from 0% to 1% of the train split of wikipedia, and 0% to 8% of the train split Book Corpus.\n",
    "\"\"\"\n",
    "\n",
    "datasets[\"validation\"].info.description = \\\n",
    "\"\"\"\n",
    "A combination of Wikipedia plus Book Corpus. This uses the range from 1% to 2% of the train split of wikipedia, and 8% to 16% of the train split Book Corpus.\n",
    "\"\"\"\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['text'],\n",
       "        num_rows: 5981122\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['text'],\n",
       "        num_rows: 5981122\n",
       "    })\n",
       "})"
      ]
     },
     "metadata": {},
     "execution_count": 34
    }
   ],
   "source": [
    "datasets"
   ]
  },
  {
   "source": [
    "## Tokenize and save the little dataset\n",
    "* It will be called `wikipedia_plus_bookcorpus`\n",
    "* max_seq_length=128"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Hashing dataset folder name 'wikipedia_plus_bookcorpus_None (max_seq_length=128)' to 'facb56894f4824388c354b52e13d2c8a421ca1f3'\nTokenized dataset cache folder: /mnt/efs/results/preprocessed-datasets/text/facb56894f4824388c354b52e13d2c8a421ca1f3\n"
     ]
    }
   ],
   "source": [
    "hashed_folder = hash_dataset_folder_name(data_args)\n",
    "\n",
    "dataset_path = os.path.join(\n",
    "    os.path.abspath(data_args.tokenized_data_cache_dir),\n",
    "    str(hashed_folder)\n",
    ")\n",
    "print(f\"Tokenized dataset cache folder: {dataset_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "{'data_collator': 'DataCollatorForWholeWordMask',\n 'dataset_config_name': (None,),\n 'dataset_name': ('wikipedia_plus_bookcorpus',),\n 'line_by_line': False,\n 'max_seq_length': 128,\n 'mlm_probability': 0.15,\n 'override_finetuning_results': False,\n 'overwrite_cache': False,\n 'pad_to_max_length': False,\n 'preprocessing_num_workers': None,\n 'reuse_tokenized_data': True,\n 'save_tokenized_data': True,\n 'task_name': None,\n 'task_names': [],\n 'tokenized_data_cache_dir': '/mnt/efs/results/preprocessed-datasets/text',\n 'train_file': None,\n 'validation_file': None,\n 'validation_split_percentage': 5}\n"
     ]
    }
   ],
   "source": [
    "pprint(data_args.__dict__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "output_type": "error",
     "ename": "NameError",
     "evalue": "name 'datasets' is not defined",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-17-b347fbd8b1bd>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mcolumn_names\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"train\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolumn_names\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0mtext_column_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"text\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m tokenizer_kwargs = dict(\n\u001b[1;32m      4\u001b[0m     \u001b[0mcache_dir\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodel_args\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_dir\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m     \u001b[0muse_fast\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodel_args\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_fast_tokenizer\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'datasets' is not defined"
     ]
    }
   ],
   "source": [
    "\n",
    "column_names = datasets[\"train\"].column_names\n",
    "text_column_name = \"text\"\n",
    "tokenizer_kwargs = dict(\n",
    "    cache_dir=model_args.cache_dir,\n",
    "    use_fast=model_args.use_fast_tokenizer,\n",
    "    revision=model_args.model_revision,\n",
    "    use_auth_token=True if model_args.use_auth_token else None,\n",
    ")\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    model_args.tokenizer_name,  # 'bert-base-cased'\n",
    "    **tokenizer_kwargs\n",
    ")\n",
    "tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (768 > 512). Running this sequence through the model will result in indexing errors\n",
      "Loading cached processed dataset at /mnt/efs/results/cache/huggingface/datasets/wikipedia/20200501.en/1.0.0/4021357e28509391eab2f8300d9b689e7e8f3a877ebb3d354b01577d497ebc63/cache-65fd888ecc1b6bf6.arrow\n",
      "Loading cached processed dataset at /mnt/efs/results/cache/huggingface/datasets/wikipedia/20200501.en/1.0.0/4021357e28509391eab2f8300d9b689e7e8f3a877ebb3d354b01577d497ebc63/cache-90f2357a182dabe4.arrow\n"
     ]
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=0.0, max=5982.0), HTML(value='')))",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "4c084d766ab74294aa64b30c51b8fb9b"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "\n"
     ]
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=0.0, max=5982.0), HTML(value='')))",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "95c1584fcb5c4f69b513bd7a8c44a82a"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "\n",
      "Saving tokenized dataset to /mnt/efs/results/preprocessed-datasets/text/facb56894f4824388c354b52e13d2c8a421ca1f3\n"
     ]
    }
   ],
   "source": [
    "tokenized_datasets = preprocess_datasets_mlm(\n",
    "    datasets, tokenizer, data_args,\n",
    "    column_names, text_column_name\n",
    ")\n",
    "\n",
    "print(f\"Saving tokenized dataset to {dataset_path}\")\n",
    "tokenized_datasets.save_to_disk(dataset_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['attention_mask', 'input_ids', 'special_tokens_mask', 'token_type_ids'],\n",
       "        num_rows: 1132942\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['attention_mask', 'input_ids', 'special_tokens_mask', 'token_type_ids'],\n",
       "        num_rows: 1133148\n",
       "    })\n",
       "})"
      ]
     },
     "metadata": {},
     "execution_count": 50
    }
   ],
   "source": [
    "# tokenized_datasets = load_from_disk(dataset_path)\n",
    "tokenized_datasets"
   ]
  },
  {
   "source": [
    "## Load Tokenized Dataset"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Hashing dataset folder name 'wikipedia_plus_bookcorpus_None (max_seq_length=128)' to 'facb56894f4824388c354b52e13d2c8a421ca1f3'\nTokenized dataset cache folder: /mnt/efs/results/preprocessed-datasets/text/facb56894f4824388c354b52e13d2c8a421ca1f3\n"
     ]
    }
   ],
   "source": [
    "hashed_folder = hash_dataset_folder_name(data_args)\n",
    "dataset_path = os.path.join(\n",
    "    os.path.abspath(data_args.tokenized_data_cache_dir),\n",
    "    str(hashed_folder)\n",
    ")\n",
    "print(f\"Tokenized dataset cache folder: {dataset_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_datasets = load_from_disk(dataset_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['attention_mask', 'input_ids', 'special_tokens_mask', 'token_type_ids'],\n",
       "        num_rows: 1132942\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['attention_mask', 'input_ids', 'special_tokens_mask', 'token_type_ids'],\n",
       "        num_rows: 1133148\n",
       "    })\n",
       "})"
      ]
     },
     "metadata": {},
     "execution_count": 23
    }
   ],
   "source": [
    "loaded_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ]
}